package com.gmcloud.common.security.service;

import com.gmcloud.common.core.utils.RedisUtil;
import org.springframework.lang.Nullable;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthorizationCode;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.OAuth2TokenType;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.stereotype.Service;
import org.springframework.util.Assert;

import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.TimeUnit;

/**
 * Redis-backed {@link OAuth2AuthorizationService} (Redis token storage service).
 *
 * <p>An authorization is written once for every credential it carries: the {@code state}
 * attribute, the {@link OAuth2AuthorizationCode}, the {@link OAuth2RefreshToken} and the
 * {@link OAuth2AccessToken}. Each copy is stored under the key
 * {@code token::<type>::<value>}, so the authorization can be found again from any of those
 * values. Lookup by authorization id is not supported.
 *
 * @author zl.sir
 * @version 1.0
 * @since 2022/8/18 14:12
 */
@Service
public class GmRedisOAuth2AuthorizationServiceImpl implements OAuth2AuthorizationService {

    /** Lifetime, in minutes, of the entry keyed by the {@code state} attribute. */
    private static final Long TIMEOUT = 10L;

    /** First segment of every Redis key written by this service. */
    private static final String AUTHORIZATION = "token";

    /** Name of the authorization attribute that holds the {@code state} value. */
    private static final String STATE = "state";

    /**
     * Key types written by {@link #save(OAuth2Authorization)}; probed in this order when the
     * caller of {@link #findByToken(String, OAuth2TokenType)} does not say which type it has.
     */
    private static final String[] LOOKUP_TYPES = {
            OAuth2ParameterNames.ACCESS_TOKEN,
            OAuth2ParameterNames.REFRESH_TOKEN,
            OAuth2ParameterNames.CODE,
            OAuth2ParameterNames.STATE
    };

    private final RedisUtil redisUtil;

    public GmRedisOAuth2AuthorizationServiceImpl(RedisUtil redisUtil) {
        this.redisUtil = redisUtil;
    }

    /**
     * Stores the authorization under one key per credential it currently carries.
     *
     * <p>The {@code state} entry lives for {@link #TIMEOUT} minutes. Token entries live for the
     * token's own lifetime ({@code expiresAt - issuedAt}), measured in seconds.
     *
     * @param authorization the authorization to store; must not be {@code null}
     * @throws IllegalArgumentException if {@code authorization} is {@code null}
     * @throws NullPointerException if a carried token has no {@code issuedAt} or {@code expiresAt}
     */
    @Override
    public void save(OAuth2Authorization authorization) {
        Assert.notNull(authorization, "authorization cannot be null");

        // Authorization still waiting for user consent: only reachable through its state value.
        String state = authorization.getAttribute(STATE);
        if (state != null) {
            redisUtil.set(buildKey(OAuth2ParameterNames.STATE, state), authorization, TIMEOUT,
                    TimeUnit.MINUTES);
        }

        storeToken(OAuth2ParameterNames.CODE,
                authorization.getToken(OAuth2AuthorizationCode.class), authorization);
        storeToken(OAuth2ParameterNames.REFRESH_TOKEN, authorization.getRefreshToken(), authorization);
        storeToken(OAuth2ParameterNames.ACCESS_TOKEN, authorization.getAccessToken(), authorization);
    }

    /**
     * Deletes every key that {@link #save(OAuth2Authorization)} would write for the given
     * authorization.
     *
     * @param authorization the authorization to remove; must not be {@code null}
     * @throws IllegalArgumentException if {@code authorization} is {@code null}
     */
    @Override
    public void remove(OAuth2Authorization authorization) {
        Assert.notNull(authorization, "authorization cannot be null");

        List<String> keys = new ArrayList<>();
        String state = authorization.getAttribute(STATE);
        if (state != null) {
            keys.add(buildKey(OAuth2ParameterNames.STATE, state));
        }
        collectTokenKey(keys, OAuth2ParameterNames.CODE,
                authorization.getToken(OAuth2AuthorizationCode.class));
        collectTokenKey(keys, OAuth2ParameterNames.REFRESH_TOKEN, authorization.getRefreshToken());
        collectTokenKey(keys, OAuth2ParameterNames.ACCESS_TOKEN, authorization.getAccessToken());

        // Nothing was ever stored for an authorization without credentials; skip the call.
        if (!keys.isEmpty()) {
            redisUtil.del(keys);
        }
    }

    /**
     * Not supported: entries are keyed by credential value only, never by authorization id.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    @Nullable
    public OAuth2Authorization findById(String id) {
        throw new UnsupportedOperationException();
    }

    /**
     * Looks up the authorization that carries the given credential value.
     *
     * @param token the state, authorization code, access token or refresh token value
     * @param tokenType the kind of credential {@code token} is, or {@code null} when the caller
     *        does not know; in that case every key type written by this service is tried
     * @return the stored authorization, or {@code null} if no entry matches
     * @throws IllegalArgumentException if {@code token} is empty
     */
    @Nullable
    @Override
    public OAuth2Authorization findByToken(String token, @Nullable OAuth2TokenType tokenType) {
        Assert.hasText(token, "token cannot be empty");

        if (tokenType != null) {
            return (OAuth2Authorization) redisUtil.get(buildKey(tokenType.getValue(), token));
        }

        // The parameter is declared @Nullable, so a missing type must not be rejected:
        // fall back to probing each key type in turn.
        for (String type : LOOKUP_TYPES) {
            OAuth2Authorization found = (OAuth2Authorization) redisUtil.get(buildKey(type, token));
            if (found != null) {
                return found;
            }
        }
        return null;
    }

    /**
     * Writes {@code authorization} under the key for the given token, expiring after the token's
     * lifetime. Does nothing when the authorization does not carry that token.
     */
    private void storeToken(String type, @Nullable OAuth2Authorization.Token<?> holder,
            OAuth2Authorization authorization) {
        if (holder == null) {
            return;
        }
        // Always measured in seconds: whole minutes would truncate short lifetimes down to 0.
        long ttlSeconds = ChronoUnit.SECONDS.between(
                Objects.requireNonNull(holder.getToken().getIssuedAt(),
                        type + " token has no issuedAt"),
                Objects.requireNonNull(holder.getToken().getExpiresAt(),
                        type + " token has no expiresAt"));
        if (ttlSeconds <= 0) {
            // Already expired at issuance, so it can never be presented successfully.
            // NOTE(review): how RedisUtil.set treats a non-positive timeout is not visible
            // here, so such a value is never passed to it.
            return;
        }
        redisUtil.set(buildKey(type, holder.getToken().getTokenValue()), authorization, ttlSeconds,
                TimeUnit.SECONDS);
    }

    /** Adds the Redis key for the given token to {@code keys}, if the token is present. */
    private void collectTokenKey(List<String> keys, String type,
            @Nullable OAuth2Authorization.Token<?> holder) {
        if (holder != null) {
            keys.add(buildKey(type, holder.getToken().getTokenValue()));
        }
    }

    /** Builds the Redis key {@code token::<type>::<id>}. */
    private String buildKey(String type, String id) {
        return String.format("%s::%s::%s", AUTHORIZATION, type, id);
    }
}
